import json
import logging
from typing import Any, Dict, Iterator, List, Mapping, Optional, Union

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    HumanMessage,
    HumanMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.utils import (
    convert_to_secret_str,
    get_from_dict_or_env,
)
from pydantic import ConfigDict, Field, SecretStr, model_validator

# Module logger, named after this module.
logger = logging.getLogger(__name__)

# Base URL used when neither ``coze_api_base`` nor ``COZE_API_BASE`` is supplied.
DEFAULT_API_BASE: str = "https://api.coze.com"


def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Serialize a LangChain message into a Coze ``chat_history`` entry.

    Only ``HumanMessage`` instances are sent with the ``"user"`` role; every other
    message type is sent with the ``"assistant"`` role. The content is passed
    through unchanged and always tagged as ``"text"``.
    """
    role = "user" if isinstance(message, HumanMessage) else "assistant"
    return {"role": role, "content": message.content, "content_type": "text"}


def _convert_dict_to_message(_dict: Mapping[str, Any]) -> Union[BaseMessage, None]:
    msg_type = _dict["type"]
    if msg_type != "answer":
        return None
    role = _dict["role"]
    if role == "user":
        return HumanMessage(content=_dict["content"])
    elif role == "assistant":
        return AIMessage(content=_dict.get("content", "") or "")
    else:
        return ChatMessage(content=_dict["content"], role=role)


def _convert_delta_to_message_chunk(_dict: Mapping[str, Any]) -> BaseMessageChunk:
    """Turn a streamed Coze ``message`` payload into a LangChain message chunk.

    A missing or falsy ``content`` becomes ``""``. The ``"user"`` and
    ``"assistant"`` roles map to their dedicated chunk classes; any other role is
    wrapped in a ``ChatMessageChunk``.
    """
    speaker = _dict.get("role")
    text = _dict.get("content") or ""

    if speaker == "assistant":
        return AIMessageChunk(content=text)
    if speaker == "user":
        return HumanMessageChunk(content=text)
    # NOTE(review): a payload without "role" reaches here with speaker=None,
    # which ChatMessageChunk presumably rejects (it expects a str) — confirm.
    return ChatMessageChunk(content=text, role=speaker)  # type: ignore[arg-type]


class ChatCoze(BaseChatModel):
    """ChatCoze chat models API by coze.com

    Talks to the Coze ``/open_api/v2/chat`` endpoint. The last ``HumanMessage`` of
    the input is sent as ``query``, and every input message is also sent in
    ``chat_history``.

    For more information, see https://www.coze.com/open/docs/chat
    """

    @property
    def lc_secrets(self) -> Dict[str, str]:
        # Maps the secret field name to the environment variable that supplies it.
        return {
            "coze_api_key": "COZE_API_KEY",
        }

    @property
    def lc_serializable(self) -> bool:
        return True

    coze_api_base: str = Field(default=DEFAULT_API_BASE)
    """Coze custom endpoints"""
    coze_api_key: Optional[SecretStr] = None
    """Coze API Key"""
    request_timeout: int = Field(default=60, alias="timeout")
    """request timeout for chat http requests"""
    bot_id: str = Field(default="")
    """The ID of the bot that the API interacts with."""
    conversation_id: str = Field(default="")
    """Indicate which conversation the dialog is taking place in. If there is no need to
    distinguish the context of the conversation(just a question and answer), skip this
    parameter. It will be generated by the system."""
    user: str = Field(default="")
    """The user who calls the API to chat with the bot."""
    streaming: bool = False
    """Whether to stream the response to the client.
    false: if no value is specified or set to false, a non-streaming response is
    returned. "Non-streaming response" means that all responses will be returned at once
    after they are all ready, and the client does not need to concatenate the content.
    true: set to true, partial message deltas will be sent.
    "Streaming response" will provide real-time response of the model to the client, and
    the client needs to assemble the final reply based on the type of message."""

    model_config = ConfigDict(
        populate_by_name=True,
    )

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Resolve the API base URL and API key before field validation.

        ``coze_api_base`` falls back to the ``COZE_API_BASE`` environment variable
        and then to ``DEFAULT_API_BASE``. ``coze_api_key`` falls back to
        ``COZE_API_KEY`` and is wrapped in a ``SecretStr``. No default is given for
        the key, so ``get_from_dict_or_env`` is expected to raise when it is found
        in neither place.
        """
        values["coze_api_base"] = get_from_dict_or_env(
            values,
            "coze_api_base",
            "COZE_API_BASE",
            DEFAULT_API_BASE,
        )
        values["coze_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(
                values,
                "coze_api_key",
                "COZE_API_KEY",
            )
        )

        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Coze API."""
        return {
            "bot_id": self.bot_id,
            "conversation_id": self.conversation_id,
            "user": self.user,
            "streaming": self.streaming,
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Call the Coze chat API and return the ``answer`` messages.

        When ``self.streaming`` is set, the SSE stream from ``_stream`` is
        aggregated. Otherwise a single non-streaming request is made and its JSON
        body is parsed.

        Raises:
            ValueError: If the HTTP status is not 200, or if the response body
                carries a ``code`` other than ``0``.
        """
        if self.streaming:
            stream_iter = self._stream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        # Force a non-streaming request: an SSE body cannot be parsed by .json(),
        # so a stray ``streaming=True`` in kwargs must not leak through here.
        r = self._chat(messages, **{**kwargs, "streaming": False})
        res = r.json()
        if res.get("code") != 0:
            raise ValueError(
                f"Error from Coze api response: {res.get('code')}: {res.get('msg')}, "
                f"logid: {r.headers.get('X-Tt-Logid')}"
            )

        return self._create_chat_result(res.get("messages") or [])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream ``answer`` deltas from the Coze chat API as generation chunks.

        An SSE request is always made, regardless of ``self.streaming``. Lines
        without a ``data:`` marker are skipped, and iteration stops at the ``done``
        event. Only ``message`` events whose payload ``type`` is ``answer`` are
        yielded. ``stop`` is accepted for interface compatibility but is not sent
        to the API.

        Raises:
            ValueError: On a non-200 HTTP status or an ``error`` event.
        """
        # Always request SSE. With streaming=False the API returns one JSON body
        # that has no "data:" lines, so ``.stream()`` would silently yield nothing.
        res = self._chat(messages, **{**kwargs, "streaming": True})
        try:
            for raw_line in res.iter_lines():
                line = raw_line.decode("utf-8").strip("\r\n")
                _, marker, data = line.partition("data:")
                if not marker:
                    continue
                response = json.loads(data)
                event = response.get("event")
                if event == "done":
                    break
                if event == "error":
                    # NOTE(review): assumes the error payload is under
                    # "error_information" with "code"/"msg" keys — confirm
                    # against the Coze v2 streaming docs.
                    info = response.get("error_information") or {}
                    raise ValueError(
                        f"Error from Coze api response: {info.get('code')}: "
                        f"{info.get('msg')}, logid: {res.headers.get('X-Tt-Logid')}"
                    )
                if event != "message" or response["message"]["type"] != "answer":
                    continue
                chunk = _convert_delta_to_message_chunk(response["message"])
                cg_chunk = ChatGenerationChunk(message=chunk)
                if run_manager:
                    run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
                yield cg_chunk
        finally:
            # With stream=True the connection stays checked out until the body is
            # fully read or closed. Breaking on "done", an exception, or an
            # abandoned generator would otherwise leak it.
            res.close()

    def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
        """Build the ``/open_api/v2/chat`` payload and POST it to Coze.

        ``kwargs`` override the instance defaults for ``bot_id``,
        ``conversation_id``, ``user`` and ``streaming``; any other keys are
        ignored. The last ``HumanMessage`` becomes ``query``, and every message
        (including that one) is also sent in ``chat_history``.
        NOTE(review): confirm whether the API expects the query to be repeated in
        the history.

        Returns:
            The HTTP response. When ``streaming`` is true the body is not yet
            read, and the caller is responsible for closing it.

        Raises:
            ValueError: If the HTTP status is not 200.
        """
        parameters = {**self._default_params, **kwargs}

        query = ""
        chat_history = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                query = f"{msg.content}"  # overwrite, to get last user message as query
            chat_history.append(_convert_message_to_dict(msg))

        streaming = parameters["streaming"]
        payload: Dict[str, Any] = {
            "conversation_id": parameters["conversation_id"],
            "bot_id": parameters["bot_id"],
            "user": parameters["user"],
            "query": query,
            "stream": streaming,
        }
        if chat_history:
            payload["chat_history"] = chat_history

        # Tolerate a configured base URL with a trailing slash.
        url = self.coze_api_base.rstrip("/") + "/open_api/v2/chat"
        api_key = self.coze_api_key.get_secret_value() if self.coze_api_key else ""

        res = requests.post(
            url=url,
            timeout=self.request_timeout,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}",
            },
            json=payload,
            stream=streaming,
        )
        if res.status_code != 200:
            logid = res.headers.get("X-Tt-Logid")
            # Release the (possibly streaming) connection before raising.
            res.close()
            raise ValueError(f"Error from Coze api response: {res}, logid: {logid}")
        return res

    def _create_chat_result(self, messages: List[Mapping[str, Any]]) -> ChatResult:
        """Wrap the convertible entries of a Coze ``messages`` list in a result.

        Entries for which ``_convert_dict_to_message`` returns ``None`` (for
        example, non-``answer`` types) are dropped. ``llm_output`` carries empty
        ``token_usage`` / ``model`` placeholders.
        """
        generations = [
            ChatGeneration(message=msg)
            for msg in map(_convert_dict_to_message, messages)
            if msg is not None
        ]

        llm_output = {"token_usage": "", "model": ""}
        return ChatResult(generations=generations, llm_output=llm_output)

    @property
    def _llm_type(self) -> str:
        return "coze-chat"
